function [y U1 Obj] = StructuredAdaPurG(data, labels, alpha, beta, knn)
% [y, U1, Obj] = StructuredAdaPurG(data, labels, alpha, beta, knn)
%
% Code for the following problem:
% min_{U, S, F, C_i} sum_i((x_i-x_j)^2 s_i^v + \gamma s_i^2) + \sum_i||U - C^i||_F + \sum_{ij}w_{ij}\mu^i\mu^jTr((A^i - C^i).(A^j - C^j))
% s.t. s1 = 1, S >= 0, rank(L) = n-c, A^i >= C^i >= 0
% (A^i is the per-view graph, S{v} in the code; A^i - C^i is its
%  inconsistent part E{v}.)
%
% ATTN1: This package is free for academic usage. The code was developed by Dr. Shudong Huang (huangsd@scu.edu.cn). You can run
% it at your own risk. For other purposes, please contact Prof. Jiancheng Lv (lvjiancheng@scu.edu.cn)
%
% ATTN2: This package was developed by Dr. Huang (huangsd@scu.edu.cn). For any problem concerning the code, please feel
% free to contact Dr. Huang.
%
% Inputs:
%   data        - a cell array with one cell per view; each cell is an
%                 n*d_v matrix with one row per sample
%   labels      - ground-truth labels, n by 1. Only its length (number of
%                 samples) and its number of distinct values (number of
%                 clusters) are used.
%   alpha, beta - hyperparameters for the algorithm (default 1 and 1)
%   knn         - number of adaptive neighbours per sample (default 10);
%                 must satisfy 1 <= knn <= n-2
%
% Outputs:
%   y   - cluster label of each sample (n by 1): the connected component
%         each sample falls into in the learned graph U1
%   U1  - the learned target graph, symmetrized as (U+U')/2
%   Obj - objective value recorded at each iteration
%
% Written by: Shudong Huang (huangsd@scu.edu.cn)
% 2020/07/15
%
% To visualize the learned graph:
% figure()
% [Uvis,PS] = mapminmax(U1,0,1);
% imshow(Uvis)
% colormap('jet');
% or:
% imshow(Uvis,'InitialMagnification','fit')
% colormap('jet');
%
% Example:
% [y, U1, Obj] = StructuredAdaPurG(data, labels, 1e4, 1e-4);

if nargin < 3
    alpha = 1;
end
if nargin < 4
    beta = 1;
end
if nargin < 5
    knn = 10;
end
% number of data samples
numSmp = length(labels);
% number of clusters
numC = length(unique(labels));
% number of views
numView = length(data);

% The neighbour updates index sorted rows at 2:knn+2, so knn+2 must not
% exceed numSmp. knn = 0 would produce all-zero graphs (and NaN in U).
if knn < 1 || knn > numSmp - 2
    error('StructuredAdaPurG:badKnn', ...
        'knn must be in [1, %d] for %d samples.', numSmp - 2, numSmp);
end

% initialize \mu: uniform view weights
mu = ones(numView,1) / numView;
% initialize \lambda: random integer in 1..10 (result depends on RNG state)
lambda = randperm(10,1);

% === Normalization: z-score every sample (row) within each view ===
for i = 1:numView
    for j = 1:numSmp
        normItem = std(data{i}(j,:));
        if (0 == normItem)
            normItem = eps;
        end
        data{i}(j,:) = (data{i}(j,:) - mean(data{i}(j,:))) / normItem;
    end
end

% === Initialization: one adaptive-neighbour graph per view ===
S0 = cell(numView,1);
for i = 1:numView
    [S0{i}, ~] = InitializeSIGs(data{i}', knn, 0);
end

% initialize U as the row-normalized average of the view graphs, then F
U = zeros(numSmp);
for i = 1:numView
    U = U + S0{i};
end
U = U/numView;
for j = 1:numSmp
    U(j,:) = U(j,:)/sum(U(j,:));
end

sU = (U+U')/2;
D = diag(sum(sU));
L = D - sU;
[F, ~, ~] = eig1(L, numC, 0);
U = sU;

% per-view pairwise distances (L2_distance_1 takes one sample per column)
idxx = cell(1,numView);
ed = cell(1,numView);
ed1 = cell(1,numView);
for v = 1:numView
    ed{v} = L2_distance_1(data{v}', data{v}');
end

% fully connected setting: every upper-triangular entry (incl. diagonal)
knn_idx = true(numSmp);
up_knn_idx = triu(knn_idx);
% ... update ... %
ITER = 30;
zr = 10e-10;   % tolerance on eigenvalue sums for the rank constraint

% coefficient matrix of view interactions; does not change across iterations
W = alpha*ones(numView) - diag(alpha*ones(1,numView)) + diag(beta*ones(1,numView));

% preallocated per-view working storage
S = S0;
C = cell(1,numView);
P = cell(1,numView);
baA = cell(1,numView);
special_baA = cell(1,numView);
baE = cell(1,numView);
special_baE = cell(1,numView);
zeta = cell(1,numView);
gammaV = zeros(1,numView);   % per-view regularization weight \gamma

for iter = 1:ITER
    % --- update C^{(v)}: solve the numView x numView linear system jointly
    % for all upper-triangular entries, then clip into [0, S{v}] ---
    M = W.*(mu*mu') + diag(mu);
    commom_baA = zeros(numSmp);
    for v = 1:numView
        baA{v} = alpha*mu(v)*S{v};
        special_baA{v} = beta*mu(v)*S{v};
        commom_baA = commom_baA + baA{v};
    end
    for v = 1:numView
        true_baA = commom_baA - baA{v} + special_baA{v};
        temp = full(mu(v)*(U + true_baA));
        P{v} = temp(up_knn_idx);
    end
    right_p = cat(2, P{:})';
    % det(M) == 0 is practically never exact in floating point; the
    % reciprocal condition number detects (near-)singular M reliably.
    if rcond(M) < eps
        solution = (pinv(M) * right_p)';
        fprintf('------------')
    else
        solution = (M \ right_p)';
    end
    solution(solution<0) = 0;
    for v = 1:numView
        C{v} = zeros(numSmp);
        C{v}(up_knn_idx) = solution(:,v);
        C{v} = max(C{v}, C{v}');
        C{v} = min(C{v}, S{v});
    end

    % --- update S^{(v)}: adaptive-neighbour graph on distances shifted by
    % the (symmetrized) coupling with the other views' inconsistent parts ---
    commom_baE = zeros(numSmp);
    for v = 1:numView
        baE{v} = alpha*mu(v)*(S{v}-C{v});
        special_baE{v} = beta*mu(v)*(S{v}-C{v});
        commom_baE = commom_baE + baE{v};
    end
    for v = 1:numView
        temp = full(mu(v)*(commom_baE - baE{v} + special_baE{v}));
        zeta{v} = max(temp, temp');
        ed1{v} = ed{v} + zeta{v};
        [~, idxx{v}] = sort(ed1{v}, 2); % sort each row
    end
    for v = 1:numView
        Snew = zeros(numSmp);
        rr = zeros(1,numSmp);
        for i = 1:numSmp
            % skip position 1 (the sample itself); keep knn+1 nearest
            id = idxx{v}(i,2:knn+2);
            di = ed1{v}(i, id);
            numerator = di(knn+1)-di;
            denominator1 = knn*di(knn+1)-sum(di(1:knn));
            rr(i) = 0.5*denominator1;
            Snew(i,id) = max(numerator/(denominator1+eps),0);
        end
        S{v} = Snew;
        gammaV(v) = mean(rr);
    end

    % --- update U ---
    U = updateU(C, F, mu, lambda, numSmp, numView);

    % --- update \mu: inverse distance between U and each C^{(v)} ---
    for v = 1:numView
        distUC = norm(U - C{v},'fro')^2;
        if distUC == 0
            distUC = eps;
        end
        mu(v) = 0.5/sqrt(distUC);
    end

    % --- update F from the Laplacian of the symmetrized U ---
    U1 = (U + U')/2;
    D = diag(sum(U1));
    L = D - U1;
    F_old = F;   % kept so F can be restored if lambda is halved
    [F, ~, ev] = eig1(L, numC, 0);

    % --- objective ---
    Obj(iter) = computeObjective(U, S, C, ed, gammaV, W, mu, lambda, F, L);

    % --- update \lambda to enforce exactly numC zero eigenvalues ---
    fn1 = sum(ev(1:numC));
    fn2 = sum(ev(1:numC + 1));
    if fn1 > zr
        lambda = lambda*2;
    elseif fn2 < zr
        lambda = lambda/2;
        F = F_old;
    else
        fprintf('------------ \n')
        break;
    end
end

% clusters = connected components of the learned graph (U1 is exactly
% symmetric because it is (U+U')/2)
y = conncomp(graph(U1))';
clusternum = max(y);
if clusternum ~= numC
    warning('StructuredAdaPurG:clusterNum', ...
        'Can not find the correct cluster number: %d', numC);
end
% [ACC, MIhat, Purity] = ClusteringMeasure(labels, y);
% res = EvaluationMetrics(labels, y);
end


function obj = computeObjective(U, S, C, ed, gammaV, W, mu, lambda, F, L)
% Objective value of StructuredAdaPurG for the current iterate:
%   sum_v ||U - C^v||_F
% + sum_{i<=j} W(i,j) mu_i mu_j <E^i, E^j>          (E^v = S^v - C^v)
% + sum_v sum_{ab} (ed^v_ab S^v_ab + gamma_v (S^v_ab)^2)
% + 2 lambda Tr(F' L F)
numView = numel(S);
E = cell(1, numView);
obj1 = 0;
for v = 1:numView
    E{v} = S{v} - C{v};
    obj1 = obj1 + norm(U - C{v},'fro');
end
obj2 = 0;
for i = 1:numView
    for j = i:numView
        % Tr(E_i*E_j') equals the sum of the elementwise product, in O(n^2)
        obj2 = obj2 + W(i,j)*mu(i)*mu(j)*sum(sum(E{i}.*E{j}));
    end
end
obj3 = 0;
for v = 1:numView
    % sum over ALL entries (a,b); the original nested loop reused one index
    obj3 = obj3 + sum(sum(ed{v}.*S{v})) + gammaV(v)*sum(sum(S{v}.^2));
end
obj4 = 2*lambda*trace(F'*L*F);
obj = obj1 + obj2 + obj3 + obj4;
end


function [S, D] = InitializeSIGs(X, k, issymmetric)
% Build an adaptive-neighbour similarity graph from data points.
%   X           - data matrix with one data point per column
%   k           - neighbours per point (default 5)
%   issymmetric - if equal to 1 (the default), return (S+S')/2
% Returns
%   S - similarity matrix, one row per data point; row i is nonzero only
%       on the k+1 points ranked 2..k+2 by distance (rank 1 is skipped)
%   D - pairwise distance matrix produced by L2_distance_1(X, X)
% Ref: F. Nie, X. Wang, M. I. Jordan, and H. Huang, The constrained
% Laplacian rank algorithm for graph-based clustering, in AAAI, 2016.

if nargin < 2
    k = 5;
end
if nargin < 3
    issymmetric = 1;
end

[~, numPts] = size(X);
D = L2_distance_1(X, X);
[~, order] = sort(D, 2);   % neighbour ranking per row

S = zeros(numPts);
for row = 1:numPts
    nbr = order(row, 2:k+2);
    dNbr = D(row, nbr);
    dFar = dNbr(k+1);        % distance to the (k+1)-th candidate
    S(row, nbr) = (dFar - dNbr) / (k*dFar - sum(dNbr(1:k)) + eps);
end

if issymmetric == 1
    S = (S + S')/2;
end
end


function U = updateU(C, F, mu, lambda, numSmp, numView)
% Row-wise update of the target graph U.
% For each sample, the mu-weighted combination of the graphs C{1..numView}
% is shifted by 0.5*lambda times the distances between rows of F, divided
% by sum(mu), and projected with EProjSimplex_new. Only entries where the
% weighted combination is > 0 are updated; all others stay zero.
%   C      - cell array of numSmp*numSmp graphs (first numView are used)
%   F      - embedding matrix, one row per sample
%   mu     - view weights
%   lambda - weight on the embedding-distance term
    fDist = L2_distance_1(F', F');

    % mu-weighted sum of the view graphs, accumulated view by view
    weightedC = zeros(numSmp);
    for v = 1:numView
        weightedC = weightedC + mu(v)*C{v};
    end
    muTotal = sum(mu);

    U = zeros(numSmp);
    for row = 1:numSmp
        cRow = weightedC(row, :);
        support = find(cRow > 0);
        target = (cRow(support) - 0.5*lambda*fDist(row, support)) / muTotal;
        U(row, support) = EProjSimplex_new(target);
    end
end


